Add state_dict converter for DeepSeekv3 in torchtitan #1538


Merged: 14 commits into main on Aug 12, 2025

Conversation

@wwwjn (Contributor) commented on Aug 6, 2025

Support loading DeepSeek HF weights into the DeepSeek-V3 model:

  1. Support splitting / concatenating weights for GroupedExperts (see the sketch after this list)
  2. Support dequantization while loading HF checkpoints
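
The actual converter lives in this PR's diff; the following is only a minimal sketch of the split/concat idea for GroupedExperts, with hypothetical names (`concat_expert_weights`, `split_expert_weights`, and the HF key pattern in the comments), not the torchtitan implementation itself.

```python
import torch

def concat_expert_weights(per_expert: list[torch.Tensor]) -> torch.Tensor:
    # HF stores one 2-D weight per expert, e.g. (hypothetical key pattern)
    #   model.layers.{L}.mlp.experts.{E}.gate_proj.weight
    # GroupedExperts keeps them as a single stacked parameter of shape
    # (num_experts, out_dim, in_dim), so loading concatenates along a new dim 0.
    return torch.stack(per_expert, dim=0)

def split_expert_weights(grouped: torch.Tensor) -> list[torch.Tensor]:
    # Reverse direction (torchtitan -> HF): unbind the expert dimension back
    # into one 2-D tensor per expert.
    return list(torch.unbind(grouped, dim=0))
```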

Numerical verification (using the offline conversion script):

python convert_from_hf.py /data/users/jianiw/dsv3-weights outputs/checkpoint-dsv3-cpu --model_name deepseek_v3 --model_flavor 671B > cpu_convert.txt 2>&1
[Screenshots: numerical verification output, 2025-08-11]

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Aug 6, 2025
@wwwjn changed the title from "[WIP] Add state_dict converter for DeepSeekv3 in torchtitan" to "Add state_dict converter for DeepSeekv3 in torchtitan" on Aug 11, 2025
@wwwjn wwwjn marked this pull request as ready for review August 11, 2025 23:33
@wwwjn wwwjn requested review from ankitade and ebsmothers August 11, 2025 23:33
@@ -16,12 +16,12 @@
from tokenizer.tiktoken import BaseTokenizer, IGNORE_INDEX
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transform import CLIPTransform
from utils import load_image

NOTE: This change is from running pre-commit.

@@ -282,10 +282,12 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
        self.register_buffer(
            "expert_bias",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=True,
        )

NOTE: Explicitly specify whether the registered buffer is persistent. When persistent=False, we do not expect the buffer to be loaded from a DCP checkpoint.
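
For context, here is a minimal standalone sketch (a toy module, not the PR's MoE class) of how `persistent` controls what ends up in the state dict, and therefore what DCP will try to load; the second buffer name is hypothetical.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, num_experts: int = 4):
        super().__init__()
        # persistent=True: included in state_dict(), so checkpoints save and load it.
        self.register_buffer(
            "expert_bias",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=True,
        )
        # persistent=False: excluded from state_dict(), so a checkpoint load never
        # expects a matching entry for it.
        self.register_buffer(
            "tokens_per_expert",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=False,
        )

print(list(Toy().state_dict().keys()))  # ['expert_bias']
```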

@tianyu-l (Contributor) left a comment


Maybe this needs to be rebased onto #1526 after it lands.

@wwwjn wwwjn requested a review from tianyu-l August 12, 2025 06:51
@tianyu-l (Contributor) left a comment


If I understand correctly, the conversion:

  1. can be used to convert an HF checkpoint offline from fp8 to fp32 using CPU plain tensors.
  2. can't be used to convert an HF checkpoint on the fly using GPU DTensors, because the sharding and the quantization blocks may not align well.
  3. can't be used for weight sync to produce a bf16 state dict, because fake quantization to fp8 is applied.

I think it's OK to land this PR to unblock 1, but it would be better to explain these limitations clearly somewhere.

I also had some inline comments.
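
Point 1 above is the offline path this PR unblocks. As an illustration only, here is a rough sketch of block-wise fp8 dequantization on CPU plain tensors; the 128x128 block size and the `weight_scale_inv` companion-key convention are assumptions about the HF DeepSeek-V3 checkpoint layout, not taken from this PR's code.

```python
import torch

def dequantize_blockwise(
    weight_fp8: torch.Tensor,   # 2-D fp8 weight, e.g. torch.float8_e4m3fn
    scale_inv: torch.Tensor,    # per-block scales, shape (ceil(rows/B), ceil(cols/B))
    block_size: int = 128,      # assumed block size
) -> torch.Tensor:
    rows, cols = weight_fp8.shape
    # Expand each per-block scale over its block_size x block_size tile,
    # crop to the weight's shape, and rescale in fp32.
    scale = scale_inv.repeat_interleave(block_size, dim=0)
    scale = scale.repeat_interleave(block_size, dim=1)
    return weight_fp8.to(torch.float32) * scale[:rows, :cols]
```

Because the scales are tied to fixed row/column tiles, applying this to an already-sharded GPU DTensor only works if the sharding happens to align with the block boundaries, which is why the on-the-fly path (point 2) is not supported.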

@wwwjn wwwjn requested a review from tianyu-l August 12, 2025 18:10
@tianyu-l (Contributor) left a comment


LGTM to unblock

@wwwjn wwwjn merged commit a6972ae into main Aug 12, 2025
7 checks passed